import random
from abc import abstractmethod
from typing import SupportsFloat, Any, Optional, Iterable, Union

import torch
import numpy as np
# import gym_minigrid  # Needed for Utils.make_env
# import minigrid # minigrid比gym_minigrid更加难以训练
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal
from minigrid.minigrid_env import MiniGridEnv

from continual_rl.experiments.tasks.minigridenv import ActionGridEnv
from continual_rl.experiments.tasks.task_base import TaskBase
from continual_rl.experiments.tasks.preprocessor_base import PreprocessorBase
from continual_rl.utils.utils import Utils
from continual_rl.utils.env_wrappers import FrameStack, LazyFrames, OldInterfaceWrapper


class MiniGridToPyTorch(gym.ObservationWrapper):
    """Expose only the ``'image'`` entry of a MiniGrid observation, channel-first.

    The wrapped environment's observation space must be indexable by
    ``'image'``. The advertised observation space is rewritten from the
    image's ``[H, W, C]`` shape to ``[C, H, W]``, and every observation is
    returned as a ``torch`` tensor in that layout.
    """

    def __init__(self, env):
        """Wrap ``env`` and replace its observation space with a channel-first Box.

        Args:
            env: Environment whose ``observation_space['image']`` has a
                ``shape`` with the channel dimension last.
        """
        super().__init__(env)

        image_shape = self.observation_space['image'].shape
        n_channels = image_shape[-1]
        n_rows, n_cols = image_shape[0], image_shape[1]

        # MiniGrid observations are integers in [0, 10]. The 3 channels are
        # [OBJECT_IDX, COLOR_IDX, STATE], where OBJECT_IDX is in [0, 10],
        # COLOR_IDX is in [0, 5] and STATE is in [0, 2], so a single upper
        # bound of 10 covers all of them.
        # (https://github.com/maximecb/gym-minigrid/blob/master/gym_minigrid/minigrid.py)
        channel_first_shape = (n_channels, n_rows, n_cols)
        self.observation_space = gym.spaces.Box(
            low=0,
            high=10,
            shape=channel_first_shape,
            dtype=np.uint8,
        )
        # Disabled: restrict actions to turn-left / turn-right / forward.
        # self.action_space = spaces.Discrete(3)

    def observation(self, observation):
        """Return ``observation['image']`` as a tensor permuted to ``[C, H, W]``.

        Args:
            observation: Mapping holding the image under the ``'image'`` key,
                laid out as ``[H, W, C]``.

        Returns:
            A ``torch`` tensor built from the image with the last axis moved
            to the front.
        """
        # MiniGrid orders image axes [H, W, C]; PyTorch expects [C, H, W].
        image = torch.tensor(observation['image'])
        return image.permute(2, 0, 1)


class MiniGridPreprocessor(PreprocessorBase):
    """Preprocessor for MiniGrid tasks.

    Builds two zero-argument environment factories from the given spec (one
    default, one created with ``render_mode="rgb_array"``) and initialises the
    base preprocessor with the observation space of an environment produced
    by the default factory.
    """

    def __init__(self, env_spec, time_batch_size):
        """Create the wrapped env factories and initialise the base class.

        Args:
            env_spec: Passed through to ``ActionGridEnv`` when an environment
                is built.
            time_batch_size: Passed to ``FrameStack`` as its second argument.
        """
        self.env_spec = self._wrap_env(env_spec, time_batch_size)
        self.render_env_spec = self._wrap_env(env_spec, time_batch_size, render_mode="rgb_array")

        # A throwaway environment is built only to read its observation space.
        probe_env, _unused = Utils.make_env(self.env_spec)
        super().__init__(probe_env.observation_space)

        self.need_render = True

    def _wrap_env(self, env_spec, time_batch_size, render_mode=None):
        """Return a zero-argument factory producing the fully wrapped environment.

        The wrapping order, innermost first, is ``ActionGridEnv`` ->
        ``MiniGridToPyTorch`` -> ``OldInterfaceWrapper`` -> ``FrameStack``.
        ``OldInterfaceWrapper`` is there for compatibility with the older gym
        interface.

        Args:
            env_spec: First argument given to ``ActionGridEnv``.
            time_batch_size: Second argument given to ``FrameStack``.
            render_mode: Forwarded to ``ActionGridEnv`` as ``render_mode``.

        Returns:
            A callable taking no arguments; each call constructs a new
            wrapped environment.
        """
        def build_env():
            base_env = ActionGridEnv(env_spec, render_mode=render_mode)
            channel_first_env = MiniGridToPyTorch(base_env)
            legacy_api_env = OldInterfaceWrapper(channel_first_env)
            return FrameStack(legacy_api_env, time_batch_size)

        return build_env

    def preprocess(self, batched_obs):
        """Stack a batch of frame-stacked observations into one tensor.

        Only the first element is type-checked against ``LazyFrames``; every
        element is then converted with ``to_tensor()`` and the results are
        stacked along a new leading dimension. The result is intended to be
        laid out as [B, C, H, W] (assumes ``to_tensor()`` yields [C, H, W];
        not verifiable from this file).

        Args:
            batched_obs: Indexable, iterable collection of observations.

        Returns:
            The ``torch.stack`` of all converted observations.
        """
        first_obs = batched_obs[0]
        assert isinstance(first_obs, LazyFrames), f"Observation was of unexpected type: {type(first_obs)}"

        frame_tensors = [obs.to_tensor() for obs in batched_obs]
        return torch.stack(frame_tensors)

    def render_episode(self, episode_observations):
        """Turn the observations collected over an episode into a video tensor.

        The observations are packed into one array, the last axis is moved to
        position 1, a leading batch axis of size 1 is added, and the values
        are converted to float and divided by 255.

        Args:
            episode_observations: Sequence of per-step frames; presumably
                T frames of 255x255x3 each (taken from the original note,
                to be confirmed against the caller).

        Returns:
            A float tensor with one more leading dimension than the packed
            input.
        """
        # The three channels are not necessarily RGB here, so this is a
        # convenient rather than an ideal visualisation.
        # TODO: find a display format better suited to MiniGrid.
        frames = np.array(episode_observations)
        video = torch.tensor(frames).permute(0, 3, 1, 2)
        video = video.unsqueeze(0).float()
        return video / 255


class MiniGridTask(TaskBase):
    """Task type for MiniGrid environments.

    MiniGrid has a custom observation format, so it gets a dedicated task
    class that pairs the environment with ``MiniGridPreprocessor``.
    """

    def __init__(self, task_id, action_space_id, env_spec, num_timesteps, time_batch_size, eval_mode,
                 observation_space_id):
        """Build the preprocessor, look up the action space, and initialise the base task.

        Args:
            task_id: Forwarded unchanged to ``TaskBase``.
            action_space_id: Forwarded unchanged to ``TaskBase``.
            env_spec: Given to ``MiniGridPreprocessor``, which wraps it.
            num_timesteps: Forwarded unchanged to ``TaskBase``.
            time_batch_size: Given to ``MiniGridPreprocessor``.
            eval_mode: Forwarded unchanged to ``TaskBase``.
            observation_space_id: Forwarded to ``TaskBase`` by keyword.
        """
        obs_preprocessor = MiniGridPreprocessor(env_spec, time_batch_size)
        wrapped_env_spec = obs_preprocessor.env_spec

        # A throwaway environment is built only to read its action space.
        probe_env, _unused = Utils.make_env(wrapped_env_spec)
        task_action_space = probe_env.action_space

        super().__init__(
            task_id,
            action_space_id,
            obs_preprocessor,
            wrapped_env_spec,
            obs_preprocessor.observation_space,
            task_action_space,
            num_timesteps,
            eval_mode,
            render_env_spec=obs_preprocessor.render_env_spec,
            observation_space_id=observation_space_id,
        )
